import numpy as np
#import mpmath
from scipy.optimize import curve_fit
import time
#import functools
from matplotlib import pyplot as plt

def FitFuncGauss(xdata, Amp, x0, y0, wx, wy, rot_ang):
    """Evaluate a rotated 2-D Gaussian on a set of (x, y) points.

    Parameter order (as expected by ``curve_fit``):
    Amp, x0, y0, wx, wy, rot_ang.

    Parameters
    ----------
    xdata : ndarray
        Array whose last axis holds the coordinates: ``xdata[..., 0]`` is x
        and ``xdata[..., 1]`` is y.
    Amp : float
        Peak value of the Gaussian.
    x0, y0 : float
        Centre of the Gaussian, given in the same (unrotated) coordinates
        as ``xdata``.
    wx, wy : float
        Gaussian widths (standard deviations) along the rotated axes.
    rot_ang : float
        Rotation angle in radians applied to both the grid and the centre.

    Returns
    -------
    ndarray
        The model values flattened to 1-D.
    """
    x = xdata[..., 0]
    y = xdata[..., 1]
    cos_a = np.cos(rot_ang)
    sin_a = np.sin(rot_ang)
    x_rot = x*cos_a - y*sin_a
    y_rot = x*sin_a + y*cos_a
    # Rotate the centre with the same transform as the grid.  Both rotated
    # components must be computed from the ORIGINAL (x0, y0); reusing an
    # already-rotated x0 to compute y0 puts the peak in the wrong place.
    x0_rot = x0*cos_a - y0*sin_a
    y0_rot = x0*sin_a + y0*cos_a
    F = Amp*np.exp(-((x_rot-x0_rot)**2)/(2*wx**2) - ((y_rot-y0_rot)**2)/(2*wy**2))
    return F.ravel()
        
def FitFuncBose(xdata, Amp, x0, y0, wx, wy, rot_ang, A2, offset):
    """Evaluate a polylogarithm-of-a-Gaussian profile on (x, y) points.

    A rotated 2-D Gaussian ``z`` is built first, then the result is
    ``A2 * sum_{k=1..20} z**k / k**2 + offset`` (a 20-term truncation of
    the order-2 polylogarithm series).
    NOTE(review): presumably the Bose-enhanced thermal-cloud profile --
    confirm against the experiment's model.

    Parameters
    ----------
    xdata : ndarray
        Array whose last axis holds the coordinates: ``xdata[..., 0]`` is x
        and ``xdata[..., 1]`` is y.
    Amp : float
        Peak value of the inner Gaussian ``z``.
    x0, y0 : float
        Centre, given in the same (unrotated) coordinates as ``xdata``.
    wx, wy : float
        Gaussian widths along the rotated axes.
    rot_ang : float
        Rotation angle in radians applied to both the grid and the centre.
    A2 : float
        Scale factor applied to the series sum.
    offset : float
        Constant added to the result.

    Returns
    -------
    ndarray
        The model values flattened to 1-D.
    """
    x = xdata[..., 0]
    y = xdata[..., 1]
    cos_a = np.cos(rot_ang)
    sin_a = np.sin(rot_ang)
    x_rot = x*cos_a - y*sin_a
    y_rot = x*sin_a + y*cos_a
    # Both rotated centre components are derived from the ORIGINAL
    # (x0, y0); reusing the rotated x0 to compute y0 misplaces the peak.
    x0_rot = x0*cos_a - y0*sin_a
    y0_rot = x0*sin_a + y0*cos_a
    z = Amp*np.exp(-((x_rot-x0_rot)**2)/(2*wx**2) - ((y_rot-y0_rot)**2)/(2*wy**2))

    # Truncated power series for the polylogarithm of order gamma.
    gamma = 2
    zn = 0
    for k in range(1, 21):
        zn += (z**k)/(k**gamma)
    F = A2*zn + offset
    return F.ravel()

def TFPeak(xdata, Ac, x0, y0, wxc, wyc, rot_ang):
    """Evaluate a clipped inverted-parabola peak raised to the power 1.5.

    The value is ``Ac * max(0, 1 - dx**2/wxc**2 - dy**2/wyc**2)**1.5`` where
    (dx, dy) is the offset from the centre in the rotated frame.
    NOTE(review): the name suggests a Thomas-Fermi condensate profile --
    confirm.

    Parameters
    ----------
    xdata : ndarray
        Array whose last axis holds the coordinates: ``xdata[..., 0]`` is x
        and ``xdata[..., 1]`` is y.
    Ac : float
        Peak value at the centre.
    x0, y0 : float
        Centre, given in the same (unrotated) coordinates as ``xdata``.
    wxc, wyc : float
        Half-widths along the rotated axes; the profile is zero outside the
        ellipse they define.
    rot_ang : float
        Rotation angle in radians applied to both the grid and the centre.

    Returns
    -------
    ndarray
        The model values flattened to 1-D.
    """
    x = xdata[..., 0]
    y = xdata[..., 1]
    cos_a = np.cos(rot_ang)
    sin_a = np.sin(rot_ang)
    x_rot = x*cos_a - y*sin_a
    y_rot = x*sin_a + y*cos_a
    # Both rotated centre components are derived from the ORIGINAL
    # (x0, y0); reusing the rotated x0 to compute y0 misplaces the peak.
    x0_rot = x0*cos_a - y0*sin_a
    y0_rot = x0*sin_a + y0*cos_a
    temp = 1 - (x_rot-x0_rot)**2/wxc**2 - (y_rot-y0_rot)**2/wyc**2
    # Clip to zero outside the ellipse so the fractional power stays real.
    temp = np.where(temp < 0, 0, temp)
    F2 = Ac*(temp**1.5)
    return F2.ravel()

def FitBimodal(xdata, Amp, x0, y0, wx, wy, rot_ang, A2, offset, wxc, wyc, Ac):
    """Evaluate the sum of a polylog-of-Gaussian term and a clipped parabola.

    The result is ``F1 + F2 + offset`` with

    * ``F1 = A2 * sum_{k=1..20} z**k / k**2`` where ``z`` is a rotated 2-D
      Gaussian of peak ``Amp`` and widths ``wx``, ``wy``;
    * ``F2 = Ac * max(0, 1 - dx**2/wxc**2 - dy**2/wyc**2)**1.5``.

    Both terms share the same centre and rotation angle.
    NOTE(review): presumably thermal cloud + condensate peak -- confirm.

    Parameters
    ----------
    xdata : ndarray
        Array whose last axis holds the coordinates: ``xdata[..., 0]`` is x
        and ``xdata[..., 1]`` is y.
    Amp : float
        Peak value of the inner Gaussian ``z``.
    x0, y0 : float
        Common centre, in the same (unrotated) coordinates as ``xdata``.
    wx, wy : float
        Gaussian widths along the rotated axes.
    rot_ang : float
        Rotation angle in radians applied to both the grid and the centre.
    A2 : float
        Scale factor applied to the series sum.
    offset : float
        Constant added to the result.
    wxc, wyc : float
        Half-widths of the clipped parabola along the rotated axes.
    Ac : float
        Peak value of the clipped parabola.

    Returns
    -------
    ndarray
        The model values flattened to 1-D.
    """
    x = xdata[..., 0]
    y = xdata[..., 1]
    cos_a = np.cos(rot_ang)
    sin_a = np.sin(rot_ang)
    x_rot = x*cos_a - y*sin_a
    y_rot = x*sin_a + y*cos_a
    # Both rotated centre components are derived from the ORIGINAL
    # (x0, y0); reusing the rotated x0 to compute y0 misplaces the peak.
    x0_rot = x0*cos_a - y0*sin_a
    y0_rot = x0*sin_a + y0*cos_a
    dx = x_rot - x0_rot
    dy = y_rot - y0_rot

    z = Amp*np.exp(-(dx**2)/(2*wx**2) - (dy**2)/(2*wy**2))

    # Truncated power series for the polylogarithm of order gamma.
    gamma = 2
    zn = 0
    for k in range(1, 21):
        zn += (z**k)/(k**gamma)
    F1 = A2*zn

    temp = 1 - dx**2/wxc**2 - dy**2/wyc**2
    # Clip to zero outside the ellipse so the fractional power stays real.
    temp = np.where(temp < 0, 0, temp)
    F2 = Ac*(temp**1.5)

    F = F1 + F2 + offset
    return F.ravel()

def Main_Gauss(Od_Data, p0, x_LowBound, x_UpBound, Od_Lim, Fit_Type):
    """Fit one of the 2-D models to an image with ``curve_fit``.

    Pixel coordinates are measured from the image centre: column ``j`` has
    x = j - n_cols//2 and row ``i`` has y = i - n_rows//2.

    Parameters
    ----------
    Od_Data : 2-D ndarray
        Image to fit.
    p0 : sequence of float
        Initial parameter guess.
        'Gauss' uses only the first six entries
        (Amp, x0, y0, wx, wy, rot_ang); 'Bose' passes ``p0`` unchanged to
        ``FitFuncBose``; 'Bimodal' passes it unchanged to ``FitBimodal``;
        'TFPeak' picks entries [10, 1, 2, 8, 9, 5] out of a vector in
        ``FitBimodal`` order.
    x_LowBound, x_UpBound : sequence of float
        Lower/upper parameter bounds, indexed like ``p0``.  Not used by
        'Gauss', which derives its bounds from the image size.
    Od_Lim : any
        Unused; kept so existing callers keep working.
    Fit_Type : {'Gauss', 'Bose', 'TFPeak', 'Bimodal'}
        Which model to fit.

    Returns
    -------
    Od_Fit : 2-D ndarray
        'Gauss' and 'Bose': the fitted model image.
        'TFPeak': the data minus the fitted ``TFPeak`` image.
        'Bimodal': the data minus the ``TFPeak`` component of the fit.
    Check_p : ndarray of shape (n_params, 3)
        Columns are fitted value, lower bound, upper bound.

    Raises
    ------
    ValueError
        If ``Fit_Type`` is not one of the recognised names.

    Notes
    -----
    All branches except 'Gauss' draw a cut through the image with
    matplotlib and call ``plt.pause(0.01)``.
    """
    if Fit_Type not in ('Gauss', 'Bose', 'TFPeak', 'Bimodal'):
        # Fail before any work is done instead of crashing on an unbound
        # local at the return statement.
        raise ValueError('Unrecognized fit type: %r' % (Fit_Type,))

    n_rows, n_cols = np.shape(Od_Data)

    # One coordinate per pixel, so the grid always has the image's shape
    # (odd and non-square images included).  ``num`` must be an int:
    # np.linspace rejects a float sample count.
    xlist = np.linspace(-(n_cols//2), n_cols - 1 - n_cols//2, n_cols)
    ylist = np.linspace(-(n_rows//2), n_rows - 1 - n_rows//2, n_rows)
    X, Y = np.meshgrid(xlist, ylist)
    xdata = np.zeros(np.shape(X) + (2,))
    xdata[:, :, 0] = X
    xdata[:, :, 1] = Y
    ydata = Od_Data.ravel()

    # Column used for the diagnostic cut, plotted against the row coordinate.
    # NOTE(review): the "+1" is kept from the original index expression; it
    # looks like a 1-based-indexing leftover -- confirm the intended column.
    cut_col = min(n_cols//2 + 1, n_cols - 1)
    cut_axis = ylist

    if Fit_Type == 'Gauss':
        # Bounds in the model's parameter order:
        # Amp, x0, y0, wx, wy, rot_ang.
        lb = np.array([0, -n_cols/2, -n_rows/2, 0, 0, -np.pi/4])
        ub = np.array([np.inf, n_cols/2, n_rows/2,
                       n_cols*n_cols/4, n_rows*n_rows/4, np.pi/4])
        p0_gauss = None if p0 is None else p0[:6]
        popt, pcov = curve_fit(FitFuncGauss, xdata, ydata, p0=p0_gauss, bounds=(lb, ub))
        Od_Fit = FitFuncGauss(xdata, *popt).reshape(np.shape(X))
    elif Fit_Type == 'Bose':
        lb, ub = x_LowBound, x_UpBound
        popt, pcov = curve_fit(FitFuncBose, xdata, ydata, p0=p0, bounds=(lb, ub))
        Od_Fit = FitFuncBose(xdata, *popt).reshape(np.shape(X))
        # Data and model are cut along the same column so the two curves
        # have the same length and are directly comparable.
        plt.plot(cut_axis, Od_Data[:, cut_col], '.', cut_axis, Od_Fit[:, cut_col], 'k')
        plt.pause(0.01)
    elif Fit_Type == 'TFPeak':
        # Positions of (Ac, x0, y0, wxc, wyc, rot_ang) in a FitBimodal vector.
        tf_index = (10, 1, 2, 8, 9, 5)
        p0TF = [p0[i] for i in tf_index]
        lb = [x_LowBound[i] for i in tf_index]
        ub = [x_UpBound[i] for i in tf_index]
        start_time = time.time()
        popt, pcov = curve_fit(TFPeak, xdata, ydata, p0=p0TF, bounds=(lb, ub))
        print('time:%d\n'%(time.time()-start_time))
        Od_Hot = (ydata - TFPeak(xdata, *popt)).reshape(np.shape(X))
        plt.plot(cut_axis, Od_Data[:, cut_col], 'r.', cut_axis, Od_Hot[:, cut_col], 'k.')
        plt.pause(0.01)
        Od_Fit = Od_Hot
    else:  # 'Bimodal'
        lb, ub = x_LowBound, x_UpBound
        start_time = time.time()
        popt, pcov = curve_fit(FitBimodal, xdata, ydata, p0=p0, bounds=(lb, ub), xtol=1e-11, ftol=1e-11)
        print('time:%d\n'%(time.time()-start_time))

        Bimodal_Fit = FitBimodal(xdata, *popt).reshape(np.shape(X))
        # Split the fitted model into its two components by re-evaluating
        # each sub-model with the matching subset of fitted parameters.
        Cold_Fit = TFPeak(xdata, popt[10], popt[1], popt[2], popt[8], popt[9], popt[5]).reshape(np.shape(X))
        Hot_Fit = FitFuncBose(xdata, *popt[:8]).reshape(np.shape(X))

        Od_Fit = Od_Data - Cold_Fit
        plt.plot(cut_axis, Od_Data[:, cut_col], 'r.',
                 cut_axis, Bimodal_Fit[:, cut_col], 'k',
                 cut_axis, Hot_Fit[:, cut_col], 'b')
        plt.pause(0.01)

    # Fitted values next to the bounds actually used, for every fit type.
    Check_p = np.array([popt, lb, ub]).T
    return Od_Fit, Check_p
    